# This file contains the gOpt function and the simulations for the paper "A Nearly Similar Powerful Test for Mediation",
# by Kees Jan van Garderen and Noud van Giersbergen, published in the Review of Economics and Statistics.
# It simulates the results for Table 4 of the paper and writes them to a CSV file.

using Printf, CSV, DataFrames, Random, StableRNGs, Distributions, LinearAlgebra, Interpolations
# The above libraries are required. If they are not installed, please uncomment and execute the following line
# import Pkg; Pkg.add(["Printf", "CSV", "DataFrames", "Random", "StableRNGs", "Distributions", "LinearAlgebra", "Interpolations"])

# Tabulated g-opt boundary: breakpoints `_GOPT_BP` (values of the larger |t|) and the
# corresponding critical values `_GOPT_W` for the smaller |t|. The last tabulated value
# equals the two-sided 5% standard-normal critical value.
const _GOPT_BP = [0.0,0.1,0.11,0.12,0.14,0.16,0.18,0.2,0.22,0.24,0.26,0.28,0.3,0.32,0.34,0.36,0.38,0.4,0.42,0.44,0.46,0.48,0.5,0.52,0.54,0.56,0.58,0.6,0.62,0.64,0.66,0.68,0.70,0.72,0.74,0.76,0.78,0.8,0.8200000000000001,0.8400000000000001,0.86,0.88,0.9,0.92,0.94,0.96,0.98,1.0,1.02,1.04,1.06,1.08,1.1,1.12,1.14,1.16,1.18,1.2,1.22,1.24,1.26,1.28,1.30,1.32,1.34,1.36,1.38,1.4,1.42,1.44,1.46,1.48,1.5,1.52,1.54,1.56,1.58,1.6,1.62,1.64,1.66,1.68,1.7,1.72,1.74,1.76,1.78,1.8,1.82,1.84,1.86,1.88,1.9,1.92,1.94,1.96,1.98,2.0,2.02,2.04,2.06,2.08,2.1,2.12,2.14,2.16,2.18,2.2,2.22]
const _GOPT_W = [0.0,0.1,0.10442420567430982,0.10884823630492242,0.10905928990731939,0.11671970433002472,0.1362534932652685,0.1562531362049586,0.17625313620476138,0.19625313620438045,0.21625313620398562,0.2362531362035911,0.2562531362031933,0.27625313620280345,0.2962531362024044,0.3162531362020066,0.3362531362015954,0.35625313620117466,0.3762531362007406,0.3962531362003004,0.4162531361998559,0.4362531361994087,0.45625313619894653,0.4762531361984941,0.49625313619802836,0.5162531361975771,0.5362531361970972,0.5562531361966323,0.5762531361961756,0.5962531361956576,0.6162531361951945,0.6362531361947645,0.6562531361943463,0.6762531361939518,0.6962531361935707,0.7162531361932336,0.7362531361928256,0.7562531361924566,0.776253136192112,0.7962531361917707,0.8162531361914103,0.8362531361910628,0.8562531361907504,0.8762531361904211,0.8962531361901945,0.9162531361899082,0.9362531361896855,0.9562531361895058,0.976253136189357,0.9962531361891586,1.016253136188927,1.036253136188834,1.0562531361886456,1.0762531361898517,1.0962531361877554,1.1162531361886663,1.1362531361884074,1.1562531361884316,1.1762531361883997,1.1962531361885487,1.2162531361886058,1.236249169961608,1.2562491699478144,1.2761707540556324,1.2961385967950088,1.3116602729399451,1.311660272939983,1.3116602729400462,1.3116602729400606,1.3116602729401678,1.3276222323522948,1.3476222323522515,1.367622232352178,1.3876222323519896,1.40762223235168,1.4276222323513768,1.447622232351633,1.467622232349834,1.4876222323487585,1.5076222323502166,1.5276222323504027,1.5476222323479396,1.5676222323479245,1.5876222323471056,1.607622232345522,1.6276222323441385,1.6476222323424914,1.6676222323412837,1.6876222323400614,1.7076222323389318,1.7276222323381407,1.7476222323375843,1.7676222323367683,1.7876222323354636,1.8076222323344198,1.8276222323330586,1.8476222323314195,1.8676222323295186,1.8876222323272207,1.907622232325869,1.927622232325158,1.9476222323250938,1.9599639845400576,1.9599639845400576,1.9599639845400576,1.9599639845400576,1.9599639845400576,1.9599639845400576,1.9599639845400576]
# Piecewise-linear interpolant, built once at load time rather than on every call
# (gOpt is evaluated once per Monte Carlo replication).
const _GOPT_ITP = interpolate((_GOPT_BP,), _GOPT_W, Gridded(Linear()))
# Two-sided 5% standard-normal critical value, used beyond the tabulated range.
const _GOPT_UP = quantile(Normal(), 1 - 0.05 / 2)

"""
    gOpt(t::Real)

Return the g-opt critical value that the smaller absolute t-statistic is compared
against, given the larger absolute t-statistic `t`.

For `0 ≤ t ≤ 2.22` the value is a piecewise-linear interpolation of a tabulated
boundary. For `t > 2.22` it is the two-sided 5% standard-normal critical value
(≈ 1.96), which also equals the last tabulated value.

Throws `DomainError` if `t` is negative or NaN.
"""
function gOpt(t::Real)
    # `t` is an absolute t-statistic; the table starts at 0 and cannot be extrapolated.
    t >= 0 || throw(DomainError(t, "gOpt expects a nonnegative absolute t-statistic"))
    return t > _GOPT_BP[end] ? _GOPT_UP : _GOPT_ITP(t)
end

"""
    OLS(m, x, y)

Estimate the two equations of the mediation model by ordinary least squares:

    m = a[1]*x + a[2] + v            (regressors `[x ι]`)
    y = b[1]*m + b[2]*x + b[3] + u   (regressors `[m x ι]`)

Here `ι` is a vector of ones.

Returns `(a, SEa, res_v, b, SEb, res_u)`:
- `a`, `b`: coefficient vectors, in the regressor order shown above.
- `SEa`, `SEb`: conventional standard errors, using `s² = e'e/(n-k)`.
- `res_v`, `res_u`: residual vectors of the two regressions.

Throws `DimensionMismatch` if the inputs differ in length. Throws `ArgumentError`
if there are not more observations than regressors. A `PosDefException` propagates
from `cholesky` if the regressors are perfectly collinear.
"""
function OLS(m::AbstractVector, x::AbstractVector, y::AbstractVector)
    n = length(x)
    (length(m) == n && length(y) == n) ||
        throw(DimensionMismatch("m, x and y must have equal length; " *
                                "got $(length(m)), $n and $(length(y))"))
    ι = ones(n)
    a, SEa, res_v = _ols_fit([x ι], m)       # M = αX + v
    b, SEb, res_u = _ols_fit([m x ι], y)     # Y = βM + γX + u
    return a, SEa, res_v, b, SEb, res_u
end

# Fit z = Xβ + e by OLS; return (β̂, standard errors, residuals) with s² = e'e/(n-k).
function _ols_fit(X::AbstractMatrix, z::AbstractVector)
    n, k = size(X)
    n > k || throw(ArgumentError("need more observations than regressors (n=$n, k=$k)"))
    # Factor X'X once: the same factorization gives β̂ and the diagonal of (X'X)⁻¹.
    F = cholesky(Symmetric(X' * X))
    β = F \ (X' * z)
    resid = z - X * β
    s2 = dot(resid, resid) / (n - k)
    se = sqrt.(s2 .* diag(inv(F)))
    return β, se, resid
end

"""
    simul(REP, dist_v, n, θ1, θ2, γ, seed)

Run the Monte Carlo experiment for one design point. Return the rejection
frequencies (in percent) of the gOpt test and of the LR test, which rejects when
the smaller |t| exceeds the 5% normal critical value. The model is

    m = θ1*x + v,    y = θ2*m + γ*x + u.

`x` is drawn once from N(0,1) and held fixed across replications. `v` and `u` are
independent draws from `dist_v`, standardized by its mean and standard deviation
(`dist_v` is assumed to have finite, nonzero variance).

# Arguments
- `REP`: number of replications (must be at least 1).
- `dist_v`: error distribution, used for both `v` and `u`.
- `n`: sample size.
- `θ1`, `θ2`, `γ`: model parameters.
- `seed`: seed of the `StableRNG` that drives all random draws.

# Returns
A 1×4 matrix `[θ1 θ2 RF_gOpt RF_LR]`. Progress is printed to stdout.
"""
function simul(REP, dist_v, n, θ1, θ2, γ, seed)
    REP >= 1 || throw(ArgumentError("REP must be at least 1, got $REP"))
    # the main simulation
    up = quantile(Normal(), 1 - 0.05 / 2)
    @printf("REP=%d\n", REP)
    @printf("n=%d\n", n)
    # v and u share one distribution; both are standardized to mean 0, variance 1
    dist_v_mean = mean(dist_v); dist_v_std = std(dist_v)
    dist_u = dist_v
    dist_u_mean = mean(dist_u); dist_u_std = std(dist_u)
    println(dist_v)

    rng = StableRNG(seed)
    # generate x once; it is kept fixed across replications
    x = rand(rng, Normal(0, 1), n)

    nreject = 0      # rejections by the gOpt test
    nrejectLR = 0    # rejections by the LR test
    for _ in 1:REP
        # generate data (draw order v, then u, is part of the seeded design)
        v = (rand(rng, dist_v, n) .- dist_v_mean) ./ dist_v_std
        m = θ1 * x + v
        u = (rand(rng, dist_u, n) .- dist_u_mean) ./ dist_u_std
        y = θ2 * m + γ * x + u

        # estimate equations; only the first coefficient of each is tested
        θhat1, SEθhat1, _, θhat2, SEθhat2, _ = OLS(m, x, y)
        Ta = θhat1[1] / SEθhat1[1]   # t-statistic of x in the m-equation
        Tb = θhat2[1] / SEθhat2[1]   # t-statistic of m in the y-equation
        t1 = min(abs(Ta), abs(Tb))
        t2 = max(abs(Ta), abs(Tb))
        if t1 > gOpt(t2)
            nreject += 1
        end
        if t1 > up
            nrejectLR += 1
        end
    end
    # Equal to 100*mean of a 0/1 indicator vector of length REP.
    rf_gopt = 100 * (nreject / REP)
    rf_lr = 100 * (nrejectLR / REP)
    results = [θ1 θ2 rf_gopt rf_lr]
    @printf("θ1=%f, θ2=%f: RF-gOpt=%f   RF-LR=%f\n",θ1,θ2, results[3],results[4])
    return results
end

# ----------------------------------------------------------------------------
# Simulation design
# ----------------------------------------------------------------------------
# Monte Carlo replications per design point
REP = 10^6
# sample sizes
nvec = [50, 100, 250, 500]
# model parameters (θ1 = 0, so the product θ1·θ2 is zero in every design)
θ1 = 0.0
θ2vec = [0.0, 0.14, 0.39, 0.59]
γ = 0.0
# error distributions, with the labels used for them in the output file
distfilename = ["Normal", "T-dist(5)", "Chi2(3)", "logNormal"]
distributions = [Normal(0, 1), TDist(5), Chisq(3), LogNormal(0, 1)]

# One row per design: distribution index, n, θ1, θ2, RF-gOpt, RF-LR
K = length(distributions) * length(nvec) * length(θ2vec)
results = zeros(K, 6)

# Designs are visited distribution-major, then n, then θ2. The running row
# counter `idx` is also the seed passed to `simul`, so this order fixes the results.
idx = 0
for (dindex, dist) in enumerate(distributions), nobs in nvec, θ2 in θ2vec
    global idx = idx + 1
    res = simul(REP, dist, nobs, θ1, θ2, γ, idx)
    results[idx, :] = [dindex nobs res]
end

# Save results: column 1 holds a numeric distribution index, which is replaced
# by its label; the CSV header supplies descriptive column names.
header = ["dist", "n", "theta1", "theta2", "RF-gOpt", "RF-LR"]
filename = "Table-4-RF-REP$(REP).csv"
df = DataFrame(results[:, 2:6], :auto)
insertcols!(df, 1, :dist => [distfilename[Int(code)] for code in results[:, 1]])
CSV.write(filename, df; header = header)
